{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting Tutorial\n",
    "\n",
    "The Hail plot module (`hl.plot`) makes it easy to plot data directly from Hail expressions. This notebook walks through the plotting functions in that module; many of the examples also appear in the first tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hail as hl\n",
    "\n",
    "hl.init()\n",
    "\n",
    "from bokeh.io import show\n",
    "from bokeh.layouts import gridplot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hl.utils.get_1kg('data/')\n",
    "mt = hl.read_matrix_table('data/1kg.mt')\n",
    "table = hl.import_table('data/1kg_annotations.txt', impute=True).key_by('Sample')\n",
    "mt = mt.annotate_cols(**table[mt.s])\n",
    "mt = hl.sample_qc(mt)\n",
    "\n",
    "mt.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Histogram\n",
    "\n",
    "The `histogram()` method takes the result of a `hist` aggregation as its first argument, along with optional arguments for the legend and title of the plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate DP over all entries into 30 bins spanning 0 to 30\n",
    "dp_hist = mt.aggregate_entries(hl.agg.hist(mt.DP, 0, 30, 30))\n",
    "p = hl.plot.histogram(dp_hist, legend='DP', title='DP Histogram')\n",
    "show(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This method, like all Hail plotting methods, also allows us to pass in fields of our dataset directly. If the `range` and `bins` arguments are omitted, the range is computed from the smallest and largest values in the dataset and `bins` defaults to 50."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = hl.plot.histogram(mt.DP, range=(0, 30), bins=30)\n",
    "show(p)"
   ]
  },
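  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the binning behind these plots concrete, here is a minimal plain-Python sketch of counting values into equal-width bins over a fixed range. The helper `sketch_hist` and its sample values are hypothetical and exist only for illustration; this is not Hail's implementation, and counting a value equal to the upper bound in the last bin is an assumption. The keys of the returned dictionary mirror the fields of the struct that Hail's `hist` aggregator returns (`bin_edges`, `bin_freq`, `n_smaller`, `n_larger`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sketch_hist(values, start, end, bins):\n",
    "    # Count values into `bins` equal-width bins over [start, end].\n",
    "    width = (end - start) / bins\n",
    "    bin_edges = [start + i * width for i in range(bins + 1)]\n",
    "    bin_freq = [0] * bins\n",
    "    n_smaller = n_larger = 0\n",
    "    for v in values:\n",
    "        if v < start:\n",
    "            n_smaller += 1\n",
    "        elif v > end:\n",
    "            n_larger += 1\n",
    "        else:\n",
    "            # Assumption: a value equal to `end` falls in the last bin.\n",
    "            idx = min(int((v - start) / width), bins - 1)\n",
    "            bin_freq[idx] += 1\n",
    "    return {'bin_edges': bin_edges, 'bin_freq': bin_freq, 'n_smaller': n_smaller, 'n_larger': n_larger}\n",
    "\n",
    "\n",
    "# Hypothetical sample values: one below the range, one above it\n",
    "hist_sketch = sketch_hist([-1, 0, 4, 5, 9, 10, 12], start=0, end=10, bins=2)\n",
    "print(hist_sketch)"
   ]
  },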
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cumulative Histogram\n",
    "\n",
    "The `cumulative_histogram()` method works in a similar way to `histogram()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = hl.plot.cumulative_histogram(mt.DP, range=(0, 30), bins=30)\n",
    "show(p)"
   ]
  },
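  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A cumulative histogram can be thought of as a running total over per-bin counts. The cell below is a small plain-Python sketch of that idea. The helper `sketch_cumulative` and the example counts are hypothetical, and the optional `n_smaller` offset (for values below the plotted range) is an assumption about how such counts might be folded in, not a description of Hail's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import accumulate\n",
    "\n",
    "\n",
    "def sketch_cumulative(bin_freq, n_smaller=0):\n",
    "    # Running total of per-bin counts, offset by the count of values below the range.\n",
    "    return [n_smaller + total for total in accumulate(bin_freq)]\n",
    "\n",
    "\n",
    "# Hypothetical per-bin counts\n",
    "cumulative_sketch = sketch_cumulative([2, 3, 0, 5])\n",
    "print(cumulative_sketch)"
   ]
  },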
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scatter\n",
    "\n",
    "The `scatter()` method can also take in either Python types or Hail fields as arguments for x and y."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = hl.plot.scatter(mt.sample_qc.dp_stats.mean, mt.sample_qc.call_rate, xlabel='Mean DP', ylabel='Call Rate')\n",
    "show(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also pass in a Hail field as a `label` argument, which determines how to color the data points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mt = mt.filter_cols((mt.sample_qc.dp_stats.mean >= 4) & (mt.sample_qc.call_rate >= 0.97))\n",
    "ab = mt.AD[1] / hl.sum(mt.AD)\n",
    "filter_condition_ab = (\n",
    "    (mt.GT.is_hom_ref() & (ab <= 0.1))\n",
    "    | (mt.GT.is_het() & (ab >= 0.25) & (ab <= 0.75))\n",
    "    | (mt.GT.is_hom_var() & (ab >= 0.9))\n",
    ")\n",
    "mt = mt.filter_entries(filter_condition_ab)\n",
    "mt = hl.variant_qc(mt).cache()\n",
    "common_mt = mt.filter_rows(mt.variant_qc.AF[1] > 0.01)\n",
    "gwas = hl.linear_regression_rows(y=common_mt.CaffeineConsumption, x=common_mt.GT.n_alt_alleles(), covariates=[1.0])\n",
    "pca_eigenvalues, pca_scores, _ = hl.hwe_normalized_pca(common_mt.GT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = hl.plot.scatter(\n",
    "    pca_scores.scores[0],\n",
    "    pca_scores.scores[1],\n",
    "    label=common_mt.cols()[pca_scores.s].SuperPopulation,\n",
    "    title='PCA',\n",
    "    xlabel='PC1',\n",
    "    ylabel='PC2',\n",
    "    n_divisions=None,\n",
    ")\n",
    "show(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hail's downsample aggregator is incorporated into the `scatter()`, `qq()`,\n",
    "`join_plot` and `manhattan()` functions. The `n_divisions` parameter controls\n",
    "the factor by which values are downsampled. Using `n_divisions=None` tells the\n",
    "plot function to collect all values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p2 = hl.plot.scatter(\n",
    "    pca_scores.scores[0],\n",
    "    pca_scores.scores[1],\n",
    "    label=common_mt.cols()[pca_scores.s].SuperPopulation,\n",
    "    title='PCA (downsampled)',\n",
    "    xlabel='PC1',\n",
    "    ylabel='PC2',\n",
    "    n_divisions=50,\n",
    ")\n",
    "show(gridplot([p, p2], ncols=2, width=400, height=400))"
   ]
  },
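  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough intuition for what downsampling does to a scatter plot, the cell below keeps at most one point per cell of an `n_divisions` by `n_divisions` grid. This is a simplified, hypothetical illustration (`sketch_downsample` and the synthetic points are made up for this example) and is not the algorithm that Hail's downsample aggregator uses. It only shows why a smaller `n_divisions` leads to fewer plotted points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sketch_downsample(points, n_divisions):\n",
    "    # Keep the first point seen in each cell of an n_divisions x n_divisions grid.\n",
    "    xs = [x for x, _ in points]\n",
    "    ys = [y for _, y in points]\n",
    "    x_min, y_min = min(xs), min(ys)\n",
    "    # Fall back to a step of 1.0 when all points share a coordinate.\n",
    "    x_step = (max(xs) - x_min) / n_divisions or 1.0\n",
    "    y_step = (max(ys) - y_min) / n_divisions or 1.0\n",
    "    kept = {}\n",
    "    for x, y in points:\n",
    "        cell = (\n",
    "            min(int((x - x_min) / x_step), n_divisions - 1),\n",
    "            min(int((y - y_min) / y_step), n_divisions - 1),\n",
    "        )\n",
    "        kept.setdefault(cell, (x, y))\n",
    "    return list(kept.values())\n",
    "\n",
    "\n",
    "# Synthetic points spread over the unit square (illustrative data only)\n",
    "dense_points = [(i / 100, (i * 37 % 100) / 100) for i in range(100)]\n",
    "coarse_points = sketch_downsample(dense_points, n_divisions=2)\n",
    "fine_points = sketch_downsample(dense_points, n_divisions=10)\n",
    "print(len(dense_points), len(coarse_points), len(fine_points))"
   ]
  },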
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2-D histogram\n",
    "\n",
    "For visualizing relationships between variables in large datasets (where scatter plots may be less informative since they highlight outliers), the `histogram_2d()` function will create a heatmap with the number of observations in each section of a 2-d grid based on two variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = hl.plot.histogram2d(pca_scores.scores[0], pca_scores.scores[1])\n",
    "show(p)"
   ]
  },
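  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The counting behind such a heatmap can be sketched in plain Python: each pair of values is assigned to a grid cell, and the heatmap color reflects the count per cell. The helper `sketch_hist2d`, its parameters, and the sample values are hypothetical and for illustration only; skipping out-of-range pairs and placing upper-bound values in the last cell are assumptions, not a description of Hail's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "\n",
    "def sketch_hist2d(xs, ys, x_range, y_range, bins):\n",
    "    # Count (x, y) pairs per cell of a bins x bins grid; out-of-range pairs are skipped.\n",
    "    (x_lo, x_hi), (y_lo, y_hi) = x_range, y_range\n",
    "    x_step = (x_hi - x_lo) / bins\n",
    "    y_step = (y_hi - y_lo) / bins\n",
    "    counts = Counter()\n",
    "    for x, y in zip(xs, ys):\n",
    "        if x_lo <= x <= x_hi and y_lo <= y <= y_hi:\n",
    "            col = min(int((x - x_lo) / x_step), bins - 1)\n",
    "            row = min(int((y - y_lo) / y_step), bins - 1)\n",
    "            counts[(col, row)] += 1\n",
    "    return counts\n",
    "\n",
    "\n",
    "# Hypothetical sample values on the unit square\n",
    "grid_counts = sketch_hist2d(\n",
    "    xs=[0.1, 0.2, 0.9, 0.8, 0.85],\n",
    "    ys=[0.1, 0.3, 0.9, 0.7, 0.2],\n",
    "    x_range=(0.0, 1.0),\n",
    "    y_range=(0.0, 1.0),\n",
    "    bins=2,\n",
    ")\n",
    "print(dict(grid_counts))"
   ]
  },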
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Q-Q (Quantile-Quantile)\n",
    "\n",
    "The `qq()` function requires either a Python type or a Hail field containing p-values to be plotted. This function also allows for downsampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = hl.plot.qq(gwas.p_value, n_divisions=None)\n",
    "p2 = hl.plot.qq(gwas.p_value, n_divisions=75)\n",
    "\n",
    "show(gridplot([p, p2], ncols=2, width=400, height=400))"
   ]
  },
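  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Q-Q plot of p-values compares the observed values against those expected under a uniform null distribution, usually on a -log10 scale. The cell below sketches how such point pairs could be computed in plain Python. The helper `sketch_qq_points` and the sample p-values are hypothetical, and the expected-quantile formula `(rank + 1) / (n + 1)` is one common convention chosen here as an assumption; it is not necessarily what `qq()` uses internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def sketch_qq_points(p_values):\n",
    "    # Pair expected and observed -log10(p), ordered from smallest to largest p-value.\n",
    "    observed = sorted(p_values)\n",
    "    n = len(observed)\n",
    "    # Assumption: expected quantiles use (rank + 1) / (n + 1); other conventions exist.\n",
    "    return [(-math.log10((rank + 1) / (n + 1)), -math.log10(p_obs)) for rank, p_obs in enumerate(observed)]\n",
    "\n",
    "\n",
    "# Hypothetical p-values\n",
    "qq_points = sketch_qq_points([0.5, 0.01, 0.2, 0.9])\n",
    "for expected, obs in qq_points:\n",
    "    print(f'{expected:.3f}  {obs:.3f}')"
   ]
  },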
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Manhattan\n",
    "\n",
    "The `manhattan()` function requires a Hail field containing p-values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = hl.plot.manhattan(gwas.p_value)\n",
    "show(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also pass in a dictionary of fields that we would like to show up as we hover over a data point, and choose not to downsample if the dataset is relatively small."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hover_fields = {'alleles': gwas.alleles}\n",
    "p = hl.plot.manhattan(gwas.p_value, hover_fields=hover_fields, n_divisions=None)\n",
    "show(p)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
